{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZloPIuRHn97X"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors. [Licensed under the Apache License, Version 2.0](#scrollTo=Afd8bu4xJOgh)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tNgCmfUvJNoF"
      },
      "outputs": [],
      "source": [
        "// #@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "// Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "// you may not use this file except in compliance with the License.\n",
        "// You may obtain a copy of the License at\n",
        "//\n",
        "// https://www.apache.org/licenses/LICENSE-2.0\n",
        "//\n",
        "// Unless required by applicable law or agreed to in writing, software\n",
        "// distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "// See the License for the specific language governing permissions and\n",
        "// limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AlvdCHw5JGyx"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/swift/tutorials/custom_differentiation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/swift/blob/master/docs/site/tutorials/custom_differentiation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/swift/blob/master/docs/site/tutorials/custom_differentiation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c_1u7JSBMx3x"
      },
      "source": [
        "# Custom differentiation\n",
        "\n",
        "This tutorial will show you how to define your own custom derivatives, perform derivative surgery, and implement your own gradient checkpointing API in just 5 lines of Swift."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gHuQo_kCTjFr"
      },
      "source": [
        "## Declaring custom derivatives"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LP0gMw56TlvH"
      },
      "source": [
        "You can define custom derivatives for any Swift function that has differentiable parameters and results. By doing that, you can even import a C function and make it differentiable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j0a8prgZTlEO"
      },
      "outputs": [],
      "source": [
        "import Glibc\n",
        "\n",
        "func sillyExp(_ x: Float) -> Float {\n",
        "    let 𝑒 = Float(M_E)\n",
        "    print(\"Taking 𝑒(\\(𝑒)) to the power of \\(x)!\")\n",
        "    return pow(𝑒, x)\n",
        "}\n",
        "\n",
        "@derivative(of: sillyExp)\n",
        "func sillyDerivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {\n",
        "    let y = sillyExp(x)\n",
        "    return (value: y, pullback: { v in v * y })\n",
        "}\n",
        "\n",
        "print(\"exp(3) =\", sillyExp(3))\n",
        "print(\"𝛁exp(3) =\", gradient(of: sillyExp)(3))"
      ]
    },
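    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The same pattern extends to functions imported from C. As a minimal sketch, the cell below wraps the C function `cbrtf` from `Glibc` in a Swift function and registers a derivative for the wrapper, using $\\frac{d}{dx}\\sqrt[3]{x} = \\frac{1}{3\\,(\\sqrt[3]{x})^2}$ (undefined at $x = 0$). The names `cubeRoot` and `cubeRootDerivative` are hypothetical and chosen only for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import Glibc\n",
        "\n",
        "// Hypothetical wrapper around the C function `cbrtf`.\n",
        "func cubeRoot(_ x: Float) -> Float {\n",
        "    return cbrtf(x)\n",
        "}\n",
        "\n",
        "// Custom derivative for the wrapper: d/dx cbrt(x) = 1 / (3 * cbrt(x)^2).\n",
        "@derivative(of: cubeRoot)\n",
        "func cubeRootDerivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {\n",
        "    let y = cubeRoot(x)\n",
        "    return (value: y, pullback: { v in v / (3 * y * y) })\n",
        "}\n",
        "\n",
        "// At x = 8 the gradient should be close to 1/12.\n",
        "print(\"𝛁cubeRoot(8) =\", gradient(of: cubeRoot)(8))"
      ]
    },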
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eQPX9r3R5OP-"
      },
      "source": [
        "## Stop derivatives from propagating\n",
        "\n",
        "The function `withoutDerivative(at:)`, commonly known as \"stop gradient\" in machine learning, stops derivatives from propagating through its argument.\n",
        "\n",
        "In addition, `withoutDerivative(at:)` can sometimes help the Swift compiler identify what not to differentiate and produce more efficient derivatives. When the compiler can detect that the derivative of a function will always be zero, it produces a warning. Explicitly using `withoutDerivative(at:)` silences that warning."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ctRt6vBO5Wle"
      },
      "outputs": [],
      "source": [
        "let x: Float = 2.0\n",
        "let y: Float = 3.0\n",
        "print(gradient(at: x, y) { x, y in\n",
        "    sin(sin(sin(x))) + withoutDerivative(at: cos(cos(cos(y))))\n",
        "})"
      ]
    },
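    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see the effect in isolation, the following sketch differentiates the same product twice: once as-is, and once with the second factor wrapped in `withoutDerivative(at:)`. The names `plain` and `stopped` are illustrative."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "// Without stopping: d(a*b)/da = b and d(a*b)/db = a.\n",
        "let plain = gradient(at: Float(2), Float(3)) { a, b in a * b }\n",
        "print(plain)  // (3.0, 2.0)\n",
        "\n",
        "// With `b` wrapped, no derivative flows back to `b`, so its gradient is zero.\n",
        "let stopped = gradient(at: Float(2), Float(3)) { a, b in a * withoutDerivative(at: b) }\n",
        "print(stopped)  // (3.0, 0.0)"
      ]
    },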
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EeV3wXQ79WS2"
      },
      "source": [
        "## Derivative surgery\n",
        "\n",
        "The method [`withDerivative(_:)`](https://www.tensorflow.org/swift/api_docs/Protocols/Differentiable#/s:10TensorFlow14DifferentiablePAAE12withGradientyxy15CotangentVectorQzzcF) runs arbitrary operations (including mutation) on the gradient at a value during the enclosing function's backpropagation.\n",
        "\n",
        "Use this to debug or make experimental tweaks to backpropagation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AHV0ryTiD6j8"
      },
      "source": [
        "### It works anywhere"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9zKSeUjTmbxq"
      },
      "source": [
        "All differentiation APIs provided by the standard library are defined generically over all types that conform to the `Differentiable` protocol: `Float`, `Double`, `Float80`, SIMD vectors, and even your own types!\n",
        "\n",
        "See the technical document [Differentiable Types](https://github.com/tensorflow/swift/blob/master/docs/DifferentiableTypes.md) for more details on the `Differentiable` protocol."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eKne7szjD8lr"
      },
      "outputs": [],
      "source": [
        "var x: Float = 30\n",
        "print(gradient(at: x) { x -> Float in\n",
        "    // Print the partial derivative with respect to the result of `sin(x)`.\n",
        "    let a = sin(x).withDerivative { print(\"∂+/∂sin = \\($0)\") }\n",
        "    // Force the partial derivative with respect to `x` to be `0.5`.\n",
        "    let b = log(x.withDerivative { (dx: inout Float) in\n",
        "        print(\"∂log/∂x = \\(dx), but rewritten to 0.5\")\n",
        "        dx = 0.5\n",
        "    })\n",
        "    return a + b\n",
        "})"
      ]
    },
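    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One experimental tweak this enables is clamping a derivative as it flows backward. The cell below is a minimal sketch, not a recommended clipping scheme: it differentiates $f(x) = x^3$ at $x = 4$, where the true derivative is $48$, and clamps the derivative reaching `x` to $[-1, 1]$. The names `clippedGradient` and `clipped` are illustrative."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "let clippedGradient = gradient(at: Float(4)) { x -> Float in\n",
        "    // Clamp the incoming derivative to [-1, 1] before it reaches `x`.\n",
        "    let clipped = x.withDerivative { (dx: inout Float) in\n",
        "        dx = max(-1, min(1, dx))\n",
        "    }\n",
        "    return clipped * clipped * clipped\n",
        "}\n",
        "print(clippedGradient)  // 1.0 instead of the unclamped 48.0"
      ]
    },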
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vmw0gkqlD9xf"
      },
      "source": [
        "### Use it in a neural network module"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JCf_OplsWzhW"
      },
      "source": [
        "Just like how we used it in a simple `Float` function, we can use it in any numerical application, like the following neural network built using the [Swift for TensorFlow Deep Learning Library](https://github.com/tensorflow/swift-apis)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fnSeAbs9-hf3"
      },
      "outputs": [],
      "source": [
        "import TensorFlow\n",
        "\n",
        "struct MLP: Layer {\n",
        "    var layer1 = Dense<Float>(inputSize: 2, outputSize: 10, activation: relu)\n",
        "    var layer2 = Dense<Float>(inputSize: 10, outputSize: 1, activation: relu)\n",
        "    \n",
        "    @differentiable\n",
        "    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {\n",
        "        let h0 = layer1(input).withDerivative { print(\"∂L/∂layer1 =\", $0) }\n",
        "        return layer2(h0)\n",
        "    }\n",
        "}\n",
        "\n",
        "var classifier = MLP()\n",
        "let optimizer = SGD(for: classifier, learningRate: 0.02)\n",
        "\n",
        "let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]\n",
        "// Labels have shape [4, 1] to match the model output; a shape of [4] would\n",
        "// broadcast against [4, 1] in the loss and compare every prediction with every label.\n",
        "let y: Tensor<Float> = [[0], [1], [1], [0]]\n",
        "\n",
        "for _ in 0..<10 {\n",
        "    let 𝛁model = gradient(at: classifier) { classifier -> Tensor<Float> in\n",
        "        let ŷ = classifier(x).withDerivative { print(\"∂L/∂ŷ =\", $0) }\n",
        "        let loss = (ŷ - y).squared().mean()\n",
        "        print(\"Loss: \\(loss)\")\n",
        "        return loss\n",
        "    }\n",
        "    optimizer.update(&classifier, along: 𝛁model)\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TzLfTj28gEUD"
      },
      "source": [
        "## Recomputing activations during backpropagation to save memory (checkpointing)\n",
        "\n",
        "Checkpointing is a traditional technique in reverse-mode automatic differentiation for saving memory. Rather than saving large intermediate values in the original computation for computing derivatives, the intermediate values are instead recomputed as needed during backpropagation.\n",
        "\n",
        "This technique has been realized in modern deep learning libraries as well. In Swift, the API [`withRecomputationInPullbacks(_:)`](https://www.tensorflow.org/swift/api_docs/Protocols/Differentiable#/s:10TensorFlow14DifferentiablePAAE28withRecomputationInPullbacksyqd__qd__xcAaBRd__lF) lets you control what to recompute during backpropagation, and it is available on all `Differentiable` types.\n",
        "\n",
        "But today, let us learn how to define our own gradient checkpointing APIs from scratch, in just a few lines of code."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5cZe-JbjwMfZ"
      },
      "source": [
        "### Our gradient checkpointing API"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "606ob1dn2v77"
      },
      "source": [
        "We can define our own gradient checkpointing API, `makeRecomputedInGradient(_:)`, in terms of the standard library function [`differentiableFunction(from:)`](https://www.tensorflow.org/swift/api_docs/Functions#/s:10TensorFlow22differentiableFunction4fromq0_x_q_tcq0_5value_15CotangentVectorQz_AEQy_tAEQy0_c8pullbacktx_q_tc_tAA14DifferentiableRzAaJR_AaJR0_r1_lF), which is a shorthand for creating a differentiable function directly from a derivative function (also called a \"vector-Jacobian product (VJP) function\").\n",
        "\n",
        "As we have seen before, the derivative function returns a tuple of the original function's result and a pullback closure. We return `original(x)` in `value:`, and call `pullback(at:in:)` on `original` to evaluate the original function again and get a pullback."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b1uU3tcVwl_1"
      },
      "outputs": [],
      "source": [
        "/// Given a differentiable function, returns the same differentiable function except when\n",
        "/// derivatives of this function are being computed. In that case, values in the original function needed\n",
        "/// for computing the derivatives will be recomputed, instead of being captured by the differential or pullback.\n",
        "///\n",
        "/// - Parameter body: The body of the differentiable function.\n",
        "/// - Returns: The same differentiable function whose derivatives, when computed, will recompute\n",
        "///   some values from the original function.\n",
        "func makeRecomputedInGradient<T: Differentiable, U: Differentiable>(\n",
        "    _ original: @escaping @differentiable (T) -> U\n",
        ") -> @differentiable (T) -> U {\n",
        "    return differentiableFunction { x in\n",
        "        (value: original(x), pullback: { v in pullback(at: x, in: original)(v) })\n",
        "    }\n",
        "}"
      ]
    },
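    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In equation form (a sketch using notation that does not appear elsewhere in this tutorial): for $y = f(x)$ with Jacobian $J_f(x)$, a derivative (VJP) function has the shape\n",
        "\n",
        "$$x \\mapsto \\bigl(f(x),\\; v \\mapsto v^\\top J_f(x)\\bigr).$$\n",
        "\n",
        "An ordinary pullback captures the intermediate values of $f$ that $J_f(x)$ depends on. The pullback built by `makeRecomputedInGradient(_:)` captures only `x` and `original`, and evaluates `original` again when it is called, trading extra computation for lower memory use."
      ]
    },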
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UbeKj7NEF7zz"
      },
      "source": [
        "### Verify it works"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oee8SXital45"
      },
      "outputs": [],
      "source": [
        "let input: Float = 10.0\n",
        "print(\"Running original computation...\")\n",
        "\n",
        "// Differentiable multiplication with checkpointing.\n",
        "let square = makeRecomputedInGradient { (x: Float) -> Float in\n",
        "    print(\"  Computing square...\")\n",
        "    return x * x\n",
        "}\n",
        "\n",
        "// Differentiate `f(x) = (cos(x))^2`.\n",
        "let (output, backprop) = valueWithPullback(at: input) { input -> Float in\n",
        "    return square(cos(input))\n",
        "}\n",
        "print(\"Running backpropagation...\")\n",
        "let grad = backprop(1)\n",
        "print(\"Gradient = \\(grad)\")"
      ]
    },
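    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check, checkpointing should change only when values are computed, not the result. Since $\\frac{d}{dx}\\cos^2 x = -2\\cos x \\sin x$, the gradient at $x = 10$ should be about $-0.913$. The sketch below compares that closed form with an ordinary, non-checkpointed gradient; the names `x0`, `analyticGradient`, and `plainGradient` are introduced here for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import Glibc\n",
        "\n",
        "let x0: Float = 10.0\n",
        "\n",
        "// Closed-form derivative of cos(x)^2.\n",
        "let analyticGradient = -2 * cos(x0) * sin(x0)\n",
        "\n",
        "// The same derivative computed by automatic differentiation, without checkpointing.\n",
        "let plainGradient = gradient(at: x0) { x -> Float in\n",
        "    let c = cos(x)\n",
        "    return c * c\n",
        "}\n",
        "\n",
        "print(\"Analytic gradient = \\(analyticGradient)\")\n",
        "print(\"Plain AD gradient = \\(plainGradient)\")"
      ]
    },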
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7SxWsSUqF9Bh"
      },
      "source": [
        "### Extend it to neural network modules\n",
        "\n",
        "In this example, we define a simple convolutional neural network.\n",
        "\n",
        "```swift\n",
        "struct Model: Layer {\n",
        "    var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6))\n",
        "    var maxPool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))\n",
        "    var flatten = Flatten<Float>()\n",
        "    var dense = Dense<Float>(inputSize: 36 * 6, outputSize: 10)\n",
        "\n",
        "    @differentiable\n",
        "    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {\n",
        "        return input.sequenced(through: conv, maxPool, flatten, dense)\n",
        "    }\n",
        "}\n",
        "```\n",
        "\n",
        "We want the activations in the convolution layer (`conv`) to be recomputed during backpropagation. However, using `makeRecomputedInGradient(_:)` directly could make the resulting code cumbersome, especially when we want to apply layers sequentially using [`sequenced(through:_:_:_:)`](https://www.tensorflow.org/swift/api_docs/Protocols/Differentiable#/s:10TensorFlow14DifferentiablePAAE9sequenced2in7through____6OutputQyd_3_AA7ContextC_qd__qd_0_qd_1_qd_2_qd_3_t5InputQyd__RszAA5LayerRd__AaMRd_0_AaMRd_1_AaMRd_2_AaMRd_3_AKQyd_0_AGRtd__AKQyd_1_AGRtd_0_AKQyd_2_AGRtd_1_AKQyd_3_AGRtd_2_r3_lF).\n",
        "\n",
        "```swift\n",
        "input.sequenced(through: conv, maxPool, flatten, dense)\n",
        "```\n",
        "\n",
        "So, why don't we define a **special layer type** that wraps a layer and makes its activations be recomputed during backpropagation? Let's do it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZP86M5RjP3OG"
      },
      "source": [
        "First, we define a `makeRecomputedInGradient(_:)` function that takes a binary function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bEm-n5H0QB8s"
      },
      "outputs": [],
      "source": [
        "// Same as the previous `makeRecomputedInGradient(_:)`, except it's for binary functions.\n",
        "func makeRecomputedInGradient<T: Differentiable, U: Differentiable, V: Differentiable>(\n",
        "    _ original: @escaping @differentiable (T, U) -> V\n",
        ") -> @differentiable (T, U) -> V {\n",
        "    return differentiableFunction { x, y in\n",
        "        (value: original(x, y), pullback: { v in pullback(at: x, y, in: original)(v) })\n",
        "    }\n",
        "}"
      ]
    },
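    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick way to exercise the binary overload is to wrap a two-argument multiplication. This is a sketch that assumes the cell above has been run; `product` is an illustrative name."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "let product = makeRecomputedInGradient { (a: Float, b: Float) -> Float in\n",
        "    print(\"  Computing product...\")\n",
        "    return a * b\n",
        "}\n",
        "\n",
        "// d(a*b)/da = b and d(a*b)/db = a, so the gradient at (3, 4) should be (4.0, 3.0).\n",
        "print(gradient(at: Float(3), Float(4)) { a, b in product(a, b) })"
      ]
    },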
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YU6DgqXxP5Nl"
      },
      "source": [
        "Then, we define a generic layer `ActivationDiscarding<Wrapped>`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ao1r_lIPGeOl"
      },
      "outputs": [],
      "source": [
        "import TensorFlow\n",
        "\n",
        "/// A layer wrapper that makes the underlying layer's activations be discarded during application\n",
        "/// and recomputed during backpropagation.\n",
        "struct ActivationDiscarding<Wrapped: Layer>: Layer {\n",
        "    /// The wrapped layer.\n",
        "    var wrapped: Wrapped\n",
        "\n",
        "    @differentiable\n",
        "    func callAsFunction(_ input: Wrapped.Input) -> Wrapped.Output {\n",
        "        let apply = makeRecomputedInGradient { (layer: Wrapped, input: Input) -> Wrapped.Output in\n",
        "            print(\"    Applying \\(Wrapped.self) layer...\")\n",
        "            return layer(input)\n",
        "        }\n",
        "        return apply(wrapped, input)\n",
        "    }\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HqPXwwuTRjmz"
      },
      "source": [
        "Finally, we can add a method on all layers that returns the same layer except its activations are discarded during application and recomputed during backpropagation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PGgkNnNNR1th"
      },
      "outputs": [],
      "source": [
        "extension Layer {\n",
        "    func discardingActivations() -> ActivationDiscarding<Self> {\n",
        "        return ActivationDiscarding(wrapped: self)\n",
        "    }\n",
        "}"
      ]
    },
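    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small usage sketch, the new method can be tried on a single `Dense` layer before it goes into a full model. The names `wrappedDense`, `sample`, and `wrappedGradient` are illustrative, and the cell assumes the definitions above have been run. The gradient has the same structure as the wrapper, so the dense layer's gradient is reached through its `wrapped` property."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "let wrappedDense = Dense<Float>(inputSize: 2, outputSize: 1).discardingActivations()\n",
        "let sample: Tensor<Float> = [[1, 2]]\n",
        "\n",
        "// Differentiate a scalar function of the layer output with respect to the layer.\n",
        "let wrappedGradient = gradient(at: wrappedDense) { layer -> Tensor<Float> in\n",
        "    layer(sample).sum()\n",
        "}\n",
        "\n",
        "// The weight gradient should have the same shape as the weight itself: [2, 1].\n",
        "print(wrappedGradient.wrapped.weight.shape)"
      ]
    },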
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8PP-NZ9XU5_n"
      },
      "source": [
        "Back in the model, all we have to change is to wrap the convolution layer into the activation-recomputing layer.\n",
        "\n",
        "```swift\n",
        "var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6)).discardingActivations()\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bCwNPtCfSbGi"
      },
      "source": [
        "Now, simply use it in the model!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gsWGwFjOJ3Md"
      },
      "outputs": [],
      "source": [
        "struct Model: Layer {\n",
        "    var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6)).discardingActivations()\n",
        "    var maxPool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))\n",
        "    var flatten = Flatten<Float>()\n",
        "    var dense = Dense<Float>(inputSize: 36 * 6, outputSize: 10)\n",
        "\n",
        "    @differentiable\n",
        "    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {\n",
        "        return input.sequenced(through: conv, maxPool, flatten, dense)\n",
        "    }\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dmFxciU6VYdF"
      },
      "source": [
        "When we run a training loop, we can see that the convolution layer's activations are computed twice: once during layer application, and once during backpropagation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-x1nYu0uVSPn"
      },
      "outputs": [],
      "source": [
        "// Use random training data.\n",
        "let x = Tensor<Float>(randomNormal: [10, 16, 16, 3])\n",
        "let y = Tensor<Int32>(rangeFrom: 0, to: 10, stride: 1)\n",
        "\n",
        "var model = Model()\n",
        "let opt = SGD(for: model)\n",
        "\n",
        "for i in 1...5 {\n",
        "    print(\"Starting training step \\(i)\")\n",
        "    print(\"  Running original computation...\")\n",
        "    let (logits, backprop) = model.appliedForBackpropagation(to: x)\n",
        "    let (loss, dL_dŷ) = valueWithGradient(at: logits) { logits in\n",
        "        softmaxCrossEntropy(logits: logits, labels: y)\n",
        "    }\n",
        "    print(\"  Loss: \\(loss)\")\n",
        "    print(\"  Running backpropagation...\")\n",
        "    let (dL_dθ, _) = backprop(dL_dŷ)\n",
        "    \n",
        "    opt.update(&model, along: dL_dθ)\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gzRaZLa_WX0u"
      },
      "source": [
        "Just like that, it is super easy to define generic differentiable programming libraries for different domains."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "custom_differentiation.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Swift",
      "name": "swift"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
